{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matpltlib as mpl\n",
    "import numpy as np\n",
    "def discrete_scatter(x1, x2, y=None, markers=None, s=10, ax=None,\n",
    "                     labels=None, padding=.2, alpha=1, c=None, markeredgewidth=None):\n",
    "    \"\"\"Adaption of matplotlib.pyplot.scatter to plot classes or clusters.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "\n",
    "    x1 : nd-array\n",
    "        input data, first axis\n",
    "\n",
    "    x2 : nd-array\n",
    "        input data, second axis\n",
    "\n",
    "    y : nd-array\n",
    "        input data, discrete labels\n",
    "\n",
    "    cmap : colormap\n",
    "        Colormap to use.\n",
    "\n",
    "    markers : list of string\n",
    "        List of markers to use, or None (which defaults to 'o').\n",
    "\n",
    "    s : int or float\n",
    "        Size of the marker\n",
    "\n",
    "    padding : float\n",
    "        Fraction of the dataset range to use for padding the axes.\n",
    "\n",
    "    alpha : float\n",
    "        Alpha value for all points.\n",
    "    \"\"\"\n",
    "    if ax is None:\n",
    "        ax = plt.gca()\n",
    "\n",
    "    if y is None:\n",
    "        y = np.zeros(len(x1))\n",
    "\n",
    "    unique_y = np.unique(y)\n",
    "\n",
    "    if markers is None:\n",
    "        markers = ['o', '^', 'v', 'D', 's', '*', 'p', 'h', 'H', '8', '<', '>'] * 10\n",
    "\n",
    "    if len(markers) == 1:\n",
    "        markers = markers * len(unique_y)\n",
    "\n",
    "    if labels is None:\n",
    "        labels = unique_y\n",
    "\n",
    "    # lines in the matplotlib sense, not actual lines\n",
    "    lines = []\n",
    "\n",
    "    current_cycler = mpl.rcParams['axes.prop_cycle']\n",
    "\n",
    "    for i, (yy, cycle) in enumerate(zip(unique_y, current_cycler())):\n",
    "        mask = y == yy\n",
    "        # if c is none, use color cycle\n",
    "        if c is None:\n",
    "            color = cycle['color']\n",
    "        elif len(c) > 1:\n",
    "            color = c[i]\n",
    "        else:\n",
    "            color = c\n",
    "        # use light edge for dark markers\n",
    "        if np.mean(colorConverter.to_rgb(color)) < .4:\n",
    "            markeredgecolor = \"grey\"\n",
    "        else:\n",
    "            markeredgecolor = \"black\"\n",
    "\n",
    "        lines.append(ax.plot(x1[mask], x2[mask], markers[i], markersize=s,\n",
    "                             label=labels[i], alpha=alpha, c=color,\n",
    "                             markeredgewidth=markeredgewidth,\n",
    "                             markeredgecolor=markeredgecolor)[0])\n",
    "\n",
    "    if padding != 0:\n",
    "        pad1 = x1.std() * padding\n",
    "        pad2 = x2.std() * padding\n",
    "        xlim = ax.get_xlim()\n",
    "        ylim = ax.get_ylim()\n",
    "        ax.set_xlim(min(x1.min() - pad1, xlim[0]), max(x1.max() + pad1, xlim[1]))\n",
    "        ax.set_ylim(min(x2.min() - pad2, ylim[0]), max(x2.max() + pad2, ylim[1]))\n",
    "\n",
    "    return lines"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
